Skip to main content

Self-play

The policy uses a custom CNN feature extractor built with PyTorch:

class CustomCNN(BaseFeaturesExtractor):
    """Small convolutional feature extractor for board observations.

    Two unpadded 3x3 convolutions (32 then 64 output channels, each followed
    by ReLU) are flattened and projected to ``features_dim`` features by a
    linear layer followed by ReLU.

    :param observation_space: Box space whose first axis is used as the
        number of input channels.
    :param features_dim: Size of the output feature vector (default 128).
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]

        conv_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Work out the flattened conv output width by pushing one sampled
        # observation (with a leading batch axis added via [None]) through
        # the conv stack, without tracking gradients.
        with th.no_grad():
            probe = th.as_tensor(observation_space.sample()[None]).float()
            flat_size = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_size, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to feature vectors."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

# Have Stable-Baselines3 build the policy on top of CustomCNN rather than its
# default feature extractor; no extractor kwargs are passed, so CustomCNN's
# own default features_dim (128) applies.
policy_kwargs = {"features_extractor_class": CustomCNN}

First, an initial policy is created and trained against the alphabeta_agent for a fixed number of timesteps.

# Bootstrap the self-play chain: build a PPO policy on top of CustomCNN,
# train it in `env`, and save it under the path "initial_policy".
# NOTE: Stable-Baselines3 PPO gathers experience in rollouts of n_steps
# (2048 by default) per environment, so a 1000-timestep budget still runs at
# least one full rollout before learn() returns.
initial_policy = PPO(policy="CnnPolicy", env=env, policy_kwargs=policy_kwargs)
initial_policy.learn(progress_bar=True, total_timesteps=1000)
initial_policy.save("initial_policy")

Next, for 10 iterations, a new environment is created with the most recently trained policy as the opponent, and a freshly initialized policy is trained against it. Each policy therefore learns to play well against its immediate predecessor, which was in turn trained against its own predecessor, and so on back to the initial policy, which was trained against alphabeta_agent. This chain does not by itself guarantee that later policies are stronger against alphabeta_agent or other opponents, because each one is optimized only against a single fixed predecessor.

def agent(obs, config, policy, deterministic=False):
    """Pick a Connect Four column for the player to move, backed by ``policy``.

    Decision order:
      1. Play a move that wins immediately for ``obs.mark``.
      2. Otherwise play a move that would win immediately for the opponent
         (mark ``3 - obs.mark``), i.e. block it.
      3. Otherwise ask ``policy`` for a column; if that column is not a legal
         move, fall back to a uniformly random legal column.

    :param obs: Observation with a flat ``board`` sequence and the player's
        ``mark``. A column counts as playable when ``board[col] == 0``.
    :param config: Game configuration; ``config.columns`` is the board width.
    :param policy: Trained model exposing ``predict`` (e.g. an SB3 ``PPO``).
    :param deterministic: Forwarded to ``policy.predict``. The default
        (``False``) keeps the original behaviour of sampling an action.
    :return: Column index (``int``) to play.
    """
    board = obs.board
    valid_moves = [col for col in range(config.columns) if board[col] == 0]

    # Steps 1 and 2: our own immediate win takes priority over blocking.
    for mark in (obs.mark, 3 - obs.mark):
        for move in valid_moves:
            if check_winning_move(obs, config, move, mark):
                return move

    # Step 3: query the policy with a single (1, rows, columns) observation.
    # Rows are derived from the board size instead of hard-coding 6x7, so the
    # shape always agrees with config.columns.
    # NOTE(review): the board is passed as-is (absolute marks 1/2), so if this
    # agent plays with a different mark than the policy was trained with, the
    # policy sees the position from the other side. Confirm against
    # ConnectFourGym.
    rows = len(board) // config.columns
    observation = np.array(board).reshape(1, rows, config.columns)
    action, _ = policy.predict(observation, deterministic=deterministic)
    col = int(action)
    if col in valid_moves:
        return col
    return random.choice(valid_moves)

def _make_opponent(opponent_policy):
    """Return an ``(obs, config)`` agent callable that plays with ``opponent_policy``.

    The policy is bound when the callable is created. Indexing ``policies[i]``
    inside a closure would instead read the global loop variable ``i`` at call
    time, so every earlier environment would end up facing the last policy.
    """
    def agent_policy(obs, config):
        return agent(obs, config, opponent_policy)

    return agent_policy


# Self-play chain: each round trains a freshly initialized PPO policy against
# the previous round's policy, which is not trained any further.
policies = [initial_policy]
for i in range(10):
    # Module-level name kept for compatibility with any later code that uses it.
    agent_policy = _make_opponent(policies[-1])

    env = ConnectFourGym(agent_policy=agent_policy)
    policy = PPO('CnnPolicy', env, policy_kwargs=policy_kwargs)
    policy.learn(total_timesteps=10000, progress_bar=True)

    policies.append(policy)